import json
import uuid
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
                    Optional, Tuple, TypeAlias, TypedDict, Union, cast)

from openai.types.chat import (ChatCompletionContentPartImageParam,
                               ChatCompletionContentPartInputAudioParam)
from openai.types.chat import \
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
from openai.types.chat import (ChatCompletionContentPartTextParam,
                               ChatCompletionMessageParam)
from transformers import AutoConfig
from typing_extensions import Required

from tensorrt_llm.inputs import (ConversationMessage, MultimodalData,
                                 MultimodalDataTracker,
                                 add_multimodal_placeholders, async_load_audio,
                                 async_load_image, async_load_video)
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
from tensorrt_llm.logger import logger


class VideoURL(TypedDict):
    """Payload carried under the ``video_url`` key of a video content part."""
    # String identifying the video. The message parser in this module extracts
    # this value and hands it to the video loader unchanged.
    url: Required[str]


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """Chat-message content part that references a video.

    This is the video variant added to the OpenAI content-part union by the
    ``ChatCompletionContentPartParam`` alias in this module.
    """
    # NOTE: the class is declared ``total=False``, but both keys are marked
    # ``Required``, so neither may be omitted.
    # Nested payload holding the video's ``url``.
    video_url: Required[VideoURL]
    # Discriminator; always the literal string "video_url".
    type: Required[Literal["video_url"]]


# Type Aliases and Constants
# A single content part of a chat message: any part type defined by the OpenAI
# client library, the video part declared in this module, or a bare string.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartVideoParam,
    str]

# TODO: Add "input_audio" to support byte_encoded audio input.
# Part types that `parse_chat_message_content_part` treats as known: when one
# of these parses to None the part is skipped with a warning instead of being
# reported as an unknown part type.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = [
    "text", "image_url", "video_url", "audio_url"
]

# Parser Functions
# `typing.cast` performs no runtime conversion; these partials only tell the
# type checker which TypedDict shape a part is being treated as.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
_AudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)


def _get_nested_url(part: Any, key: str) -> Optional[Any]:
    """Return ``part[key]["url"]``, or None if it cannot be reached.

    None is returned when ``key`` is absent, when its value is not a dict
    (for example an explicit ``None`` or a bare string), or when the nested
    dict has no ``"url"`` entry. Returning None instead of raising lets the
    caller's "empty / unparsable content" handling skip the part.
    """
    container = part.get(key)
    if not isinstance(container, dict):
        return None
    return container.get("url")


# Maps a part's "type" value to a function extracting its payload: the text
# for "text" parts, the nested "url" value for the media parts. Each extractor
# returns None when the payload is missing.
MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
    str, dict[str, str]]]] = {
        "text":
        lambda part: _TextParser(part).get("text", None),
        "image_url":
        lambda part: _get_nested_url(_ImageParser(part), "image_url"),
        "video_url":
        lambda part: _get_nested_url(_VideoParser(part), "video_url"),
        "audio_url":
        lambda part: _get_nested_url(_AudioParser(part), "audio_url"),
    }


def _parse_chat_message_content_mm_part(
    part: ChatCompletionContentPartParam
) -> tuple[str, Union[str, dict[str, str]]]:
    """Split one dict-shaped content part into ``(type, payload)``.

    The payload is produced by the extractor registered for the part's type in
    ``MM_PARSER_MAP`` and may be None when the part carries no usable content.
    A string type with no registered extractor yields a fixed placeholder
    payload so the caller can decide how to report it.

    Raises:
        ValueError: If the part's ``"type"`` entry is missing or not a string.
    """
    assert isinstance(part, dict)
    kind = part.get("type", None)

    # Reject a missing / non-string type up front.
    if not isinstance(kind, str):
        raise ValueError("Invalid 'type' field in multimodal part.")

    extractor = MM_PARSER_MAP.get(kind)
    if extractor is None:
        # Unregistered type: hand back a placeholder instead of a payload.
        return kind, "unknown part_type content"

    return kind, extractor(part)


def _make_media_loader(
    url: str,
    modality: str,
    loader: Callable[..., Any],
    mm_data_tracker: MultimodalDataTracker,
) -> Coroutine[Any, Any, Any]:
    """Create the (not yet awaited) coroutine that loads one media item.

    Args:
        url: Value taken from the content part's ``url`` field.
        modality: "image", "video" or "audio"; selects the per-modality
            keyword arguments and is used in the error message.
        loader: Async loader called as ``loader(url, **kwargs)``.
        mm_data_tracker: Tracker whose server config supplies
            ``media_io_kwargs``.

    Returns:
        A coroutine resolving to the loader's result, or to None if anything
        goes wrong while resolving the kwargs or loading the media.
    """

    async def _load():
        try:
            # The server config is optional at the public entry point, so
            # tolerate a missing config, a missing/None `media_io_kwargs`, and
            # a None entry for this modality; all of them mean "no extra
            # kwargs". NOTE(review): assumes a tracker built without a config
            # may hold None here -- confirm against MultimodalDataTracker.
            server_config = getattr(mm_data_tracker,
                                    "_multimodal_server_config", None)
            media_io_kwargs = getattr(server_config, "media_io_kwargs",
                                      None) or {}
            loader_kwargs = media_io_kwargs.get(modality) or {}
            return await loader(url, **loader_kwargs)
        except Exception as e:
            # Best effort: one broken media item is logged and resolves to
            # None rather than propagating out of the coroutine.
            logger.error(f"Failed to load {modality}: {str(e)}")
            return None

    return _load()


def parse_chat_message_content_part(
    part: ChatCompletionMessageParam,
    mm_data_tracker: MultimodalDataTracker,
) -> Optional[Any]:
    """Parse a single content part of a chat message.

    Args:
        part: One content part. Despite the annotation this is an element of a
            message's ``content`` list: either a bare string or a dict with a
            ``"type"`` key.
        mm_data_tracker: Tracker whose server config provides optional
            per-modality loader keyword arguments.

    Returns:
        * the string itself for a bare-string part or a "text" part;
        * a ``MultimodalData`` whose ``data`` is an un-awaited loader
          coroutine for "image_url", "video_url" and "audio_url" parts;
        * None when a part of a known type has no usable content (a warning
          is logged).

    Raises:
        ValueError: If the part's ``"type"`` is missing or not a string.
        NotImplementedError: If the part's type is not one of the supported
            types.
    """
    if isinstance(part, str):
        return part

    part_type, content = _parse_chat_message_content_mm_part(part)

    # A known part type with no payload is skipped rather than treated as an
    # error. The message is pre-formatted (like the other logger calls in this
    # file) so it is rendered correctly regardless of whether the logger
    # applies %-style formatting to extra positional arguments.
    if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
        logger.warning(
            f"Skipping multimodal part '{part}' (type: '{part_type}') "
            "with empty / unparsable content.")
        return None

    if part_type == "text":
        return cast(str, content)

    # Looked up at call time so the loaders are resolved through the module
    # namespace on every call.
    media_loaders = {
        "image_url": ("image", async_load_image),
        "video_url": ("video", async_load_video),
        "audio_url": ("audio", async_load_audio),
    }
    if part_type in media_loaders:
        modality, loader = media_loaders[part_type]
        return MultimodalData(modality=modality,
                              data=_make_media_loader(cast(str, content),
                                                      modality, loader,
                                                      mm_data_tracker))

    raise NotImplementedError(f"Unknown part type: {part_type}")


def parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionMessageParam],
    mm_data_tracker: MultimodalDataTracker,
) -> ConversationMessage:
    """Parse every content part of one message into a ConversationMessage.

    String results are collected as text and joined with newlines to form the
    message content; every other truthy result is collected as media. Parts
    that parse to a falsy value (None or an empty string) are dropped.
    """
    texts = []
    media = []
    for piece in parts:
        parsed = parse_chat_message_content_part(piece, mm_data_tracker)
        if not parsed:
            continue
        # Strings are text; anything else is a media descriptor.
        bucket = texts if isinstance(parsed, str) else media
        bucket.append(parsed)

    return ConversationMessage(role=role,
                               content="\n".join(texts),
                               media=media)


def parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_data_tracker: MultimodalDataTracker,
) -> ConversationMessage:
    """Parse one chat message into a ConversationMessage.

    The message's ``content`` may be absent/None (treated as no parts), a
    plain string (treated as a single text part) or an iterable of parts.
    Assistant messages additionally get their tool calls merged in, and tool
    messages their ``tool_call_id``.
    """
    role = message["role"]
    raw_content = message.get("content")

    # Normalize the three accepted content shapes into a sequence of parts.
    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [
            ChatCompletionContentPartTextParam(type="text", text=raw_content)
        ]
    else:
        parts = raw_content

    parsed = parse_chat_message_content_parts(role, parts, mm_data_tracker)

    # Role-specific fields are layered on top of the parsed content.
    if role == "assistant":
        extra_fields = _parse_assistant_message_content(message)
    elif role == "tool":
        extra_fields = _parse_tool_message_content(message)
    else:
        extra_fields = None

    if extra_fields is not None:
        parsed.update(extra_fields)
    return parsed


# Adapted from: https://github.com/vllm-project/vllm/blob/4574d48bab9c4e38b7c0a830eeefc8f0980e8c58/vllm/entrypoints/chat_utils.py#L1406
def _parse_assistant_message_content(message: Dict[str, Any]) -> Dict[str, Any]:
    result = {}
    tool_calls = message.get("tool_calls")
    if tool_calls is not None:
        result["tool_calls"] = []
        for item in tool_calls:
            if content := item["function"].get("arguments"):
                if isinstance(content, str):
                    item["function"]["arguments"] = json.loads(content)
                else:
                    item["function"]["arguments"] = content
            else:
                item["function"]["arguments"] = {}
            result["tool_calls"].append(item)

    return result


def _parse_tool_message_content(message: Dict[str, Any]) -> Dict[str, Any]:
    result = {}
    if "tool_call_id" in message:
        result["tool_call_id"] = message["tool_call_id"]
    return result


def parse_chat_messages_coroutines(
    messages: List[ChatCompletionMessageParam],
    model_config: AutoConfig,
    multimodal_server_config: Optional[MultimodalServerConfig] = None
) -> Tuple[List[ConversationMessage], Optional[Coroutine[
        Any, Any, Optional[Dict[str, List[Any]]]]], List[Any]]:
    """Parse chat messages and collect their multimodal data.

    Every message is parsed, its media items are registered with a shared
    ``MultimodalDataTracker``, and, when the tracker reports placeholder
    counts after that message, the message content is rewritten through
    ``add_multimodal_placeholders``.

    Args:
        messages: Chat messages in request order.
        model_config: Model configuration; only ``model_type`` is read.
        multimodal_server_config: Optional server configuration forwarded to
            the tracker.

    Returns:
        A 3-tuple ``(conversation, mm_coroutine, mm_placeholder_counts)``:

        * ``conversation``: the parsed messages, in input order.
        * ``mm_coroutine``: whatever the tracker's ``retrieve_all_async()``
          returns; it is not awaited here.
        * ``mm_placeholder_counts``: one entry per message holding the value
          ``placeholder_counts()`` returned right after that message was
          processed.
    """
    conversation = []
    mm_placeholder_counts = []
    mm_data_tracker = MultimodalDataTracker(model_config.model_type,
                                            multimodal_server_config)

    for msg in messages:
        parsed_msg = parse_chat_message_content(msg, mm_data_tracker)
        conversation.append(parsed_msg)
        if parsed_msg["media"]:
            for mdata in parsed_msg["media"]:
                mm_data_tracker.add_data(mdata["modality"], mdata["data"])
        # Queried once per message, after its media has been registered.
        mm_placeholder_count = mm_data_tracker.placeholder_counts()
        if mm_placeholder_count:
            parsed_msg["content"] = add_multimodal_placeholders(
                model_config.model_type, parsed_msg["content"],
                mm_placeholder_count)
        mm_placeholder_counts.append(mm_placeholder_count)

    return (conversation, mm_data_tracker.retrieve_all_async(),
            mm_placeholder_counts)


def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
    """Build an identifier for a tool call.

    With ``id_type == "kimi_k2"`` the id is ``functions.<func_name>:<idx>``.
    Any other ``id_type`` produces ``chatcmpl-tool-`` followed by a random
    UUID4 in hexadecimal form; ``func_name`` and ``idx`` are then ignored.
    """
    if id_type != "kimi_k2":
        # Default scheme: a fresh random id on every call.
        return "chatcmpl-tool-" + uuid.uuid4().hex
    return "functions.{}:{}".format(func_name, idx)


# Adapted from
# https://github.com/vllm-project/vllm/blob/44b5ce956d3cf28841615a58c1c0873af87bcfe2/vllm/entrypoints/chat_utils.py
@lru_cache
def load_chat_template(
    chat_template: Path | str | None,
    *,
    is_literal: bool = False,
) -> str | None:
    if chat_template is None:
        return None

    if is_literal:
        if isinstance(chat_template, Path):
            raise TypeError(
                "chat_template is expected to be read directly from its value")

        return chat_template

    try:
        with open(chat_template) as f:
            return f.read()
    except OSError as e:
        if isinstance(chat_template, Path):
            raise

        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
            msg = (f"The supplied chat template ({chat_template}) "
                   f"looks like a file path, but it failed to be "
                   f"opened. Reason: {e}")
            raise ValueError(msg) from e

        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
        return load_chat_template(chat_template, is_literal=True)
